cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

# Route C and C++ compilations through ccache when the tool is available.
# find_program() leaves CCACHE set to "CCACHE-NOTFOUND" when nothing is found.
# Testing the variable by name treats that value, as well as an empty or
# OFF/FALSE override supplied by the user, as "not available" instead of
# tripping over an unquoted expansion or registering a bogus launcher.
find_program(CCACHE ccache)
if(CCACHE)
  set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE}" CACHE PATH "cache Compiler")
  set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE}" CACHE PATH "cache Compiler")
  message(STATUS "CMAKE_C_COMPILER_LAUNCHER:${CMAKE_C_COMPILER_LAUNCHER}")
  message(STATUS "CMAKE_CXX_COMPILER_LAUNCHER:${CMAKE_CXX_COMPILER_LAUNCHER}")
else()
  message(STATUS "Compile without ccache")
endif()

# Declare the project and enable the C++ and C toolchains.
project(TORCHNPU LANGUAGES CXX C)

# Directory-wide switches applied to everything configured below.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)  # write compile_commands.json for tooling
set(CMAKE_INSTALL_MESSAGE NEVER)       # keep the install step quiet
set(LINUX TRUE)
# Uncomment to echo every compiler/linker command line:
# set(CMAKE_VERBOSE_MAKEFILE ON)

# When the caller passes TORCH_VERSION, hand it to the sources as the string
# macro PYTORCH_NPU_VERSION.
if(DEFINED TORCH_VERSION)
  add_definitions(-DPYTORCH_NPU_VERSION="${TORCH_VERSION}")
endif()

# Name of the shared-library target created further down in this file.
set(PLUGIN_NAME torch_npu)

# RPATH policy: binaries are linked with the install RPATH already in the build
# tree, and that RPATH points at the library's own directory and its lib/
# sub-directory ($ORIGIN is resolved by the dynamic loader at run time).
# Linker search paths are not copied into the RPATH.
set(RPATH_VALUE $ORIGIN)
set(CMAKE_INSTALL_RPATH "${RPATH_VALUE}/lib/:${RPATH_VALUE}/")
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_SKIP_BUILD_RPATH FALSE)

# Shared libraries are emitted into the directory supplied by the caller.
# NOTE(review): TORCHNPU_INSTALL_LIBDIR is not set in this file; presumably it
# is passed on the command line - confirm.
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${TORCHNPU_INSTALL_LIBDIR})

# Per-configuration C++ flags. These assignments replace CMake's defaults.
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
set(CMAKE_CXX_FLAGS_RELEASE "-O2")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g")

# LTO&PGO optimization in compile option
#
# Inputs (honoured only when the compiler ID matches "Clang"; with any other
# compiler, defining either of them aborts the configure step):
#   ENABLE_LTO - when defined, adds -flto=thin.
#   PGO_MODE   - 1 adds -fprofile-generate,
#                2 adds -fprofile-use=<current source dir>/default.profdata.
# When at least one of them is defined, the collected flags (including the base
# Clang flags below) are appended to CMAKE_CXX_FLAGS_RELEASE only.
set(IF_APPEND FALSE)
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
  set(APPEND_FLAGS "-fGNU-compatibility -fuse-ld=lld -Wno-non-pod-varargs")
  if(DEFINED ENABLE_LTO)
    set(IF_APPEND TRUE)
    set(APPEND_FLAGS "${APPEND_FLAGS} -flto=thin")
  endif()
  if(DEFINED PGO_MODE)
    set(IF_APPEND TRUE)
    if(PGO_MODE EQUAL 1)
      set(APPEND_FLAGS "${APPEND_FLAGS} -fprofile-generate")
    elseif(PGO_MODE EQUAL 2)
      set(APPEND_FLAGS "${APPEND_FLAGS} -fprofile-use=${CMAKE_CURRENT_SOURCE_DIR}/default.profdata")
    else()
      # Previously an unexpected value was accepted silently: the base Clang
      # flags were still appended but no profiling flag was. Make that visible.
      message(WARNING "Unrecognized PGO_MODE '${PGO_MODE}' (expected 1 or 2); no profile flags were added")
    endif()
  endif()
else()
  if(DEFINED ENABLE_LTO)
    message(FATAL_ERROR "Currently, LTO auto build is not supported in ${CMAKE_CXX_COMPILER_ID}")
  endif()
  if(DEFINED PGO_MODE)
    message(FATAL_ERROR "Currently, PGO auto build is not supported in ${CMAKE_CXX_COMPILER_ID}")
  endif()
endif()
if(IF_APPEND)
  set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${APPEND_FLAGS}")
endif()

# check and set CMAKE_CXX_STANDARD
# Warn when the incoming CMAKE_CXX_FLAGS already carry a -std=c++ option, then
# select C++17 / C11 without compiler-specific extensions.
string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard)
if(env_cxx_standard GREATER -1)
  # message() joins its arguments with no separator, so the first sentence
  # needs its own trailing space.
  message(
      WARNING "C++ standard version definition detected in environment variable. "
      "PyTorch requires -std=c++17. Please remove -std=c++ settings in your environment.")
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF)

# Root directories of the plugin's C++ sources and of the vendored third-party code.
set(TORCHNPU_ROOT "${PROJECT_SOURCE_DIR}/torch_npu/csrc")
set(TORCHNPU_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")

# Locate the installed PyTorch CMake package. PYTORCH_INSTALL_DIR must be
# supplied by the caller; without it Torch_DIR would point at
# "/share/cmake/Torch" and find_package() would fail with a less helpful error.
if(NOT DEFINED PYTORCH_INSTALL_DIR)
  message(FATAL_ERROR "Cannot find installed PyTorch directory")
endif()
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
find_package(Torch REQUIRED)

# NOTE(review): these four settings repeat the ones made right after project();
# presumably they are re-applied in case the Torch package files changed them -
# confirm before removing.
set(LINUX TRUE)
set(CMAKE_INSTALL_MESSAGE NEVER)
#set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Define build type
# Debug builds get -D_DEBUG; RelWithDebInfo and every other build type
# (reported as "Release") get -DNDEBUG.
if(CMAKE_BUILD_TYPE MATCHES Debug)
  message("Debug build.")
  string(APPEND CMAKE_CXX_FLAGS " -D_DEBUG")
elseif(CMAKE_BUILD_TYPE MATCHES RelWithDebInfo)
  message("RelWithDebInfo build")
  string(APPEND CMAKE_CXX_FLAGS " -DNDEBUG")
else()
  message("Release build.")
  string(APPEND CMAKE_CXX_FLAGS " -DNDEBUG")
endif()

# Common C++ compile flags, appended to CMAKE_CXX_FLAGS in a fixed order.
# -Wall/-Wextra are enabled and a set of noisy diagnostics is then switched off
# or prevented from being promoted to errors.
# NOTE(review): the original remark "Eigen fails to build with some versions,
# so convert this to a warning" presumably refers to -Wno-narrowing - confirm.
string(APPEND CMAKE_CXX_FLAGS
  " -fPIC"
  " -Wno-narrowing"
  " -Wall"
  " -Wextra"
  " -Wno-missing-field-initializers"
  " -Wno-type-limits"
  " -Wno-array-bounds"
  " -Wno-unknown-pragmas"
  " -Wno-sign-compare"
  " -Wno-unused-parameter"
  " -Wno-unused-variable"
  " -Wno-unused-function"
  " -Wno-unused-result"
  " -Wno-strict-overflow"
  " -Wno-strict-aliasing"
  " -Wno-error=deprecated-declarations"
  " -Wno-ignored-qualifiers"
)
# Only added for g++ 7.0.0 or newer.
if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.0.0))
  string(APPEND CMAKE_CXX_FLAGS " -Wno-stringop-overflow")
endif()
string(APPEND CMAKE_CXX_FLAGS
  " -Wno-error=pedantic"
  " -Wno-error=redundant-decls"
  " -Wno-error=old-style-cast"
)
# __FILENAME__ and __FILE__ are (re)defined on the command line. The $(notdir),
# $(abspath), $(subst), $(realpath) and $< pieces are passed through CMake
# untouched; they are written in GNU Make syntax, so presumably they are only
# expanded when a Makefile generator is used - confirm for other generators.
string(APPEND CMAKE_CXX_FLAGS " -D__FILENAME__='\"$(notdir $(abspath $<))\"'")
string(APPEND CMAKE_CXX_FLAGS " -Wno-builtin-macro-redefined -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"'")

# Clang-only diagnostics switches (per the original note, these flags are not
# available in GCC-4.8.5, hence the compiler check).
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
  string(APPEND CMAKE_CXX_FLAGS
    " -Wno-invalid-partial-specialization"
    " -Wno-typedef-redefinition"
    " -Wno-unknown-warning-option"
    " -Wno-unused-private-field"
    " -Wno-inconsistent-missing-override"
    " -Wno-aligned-allocation-unavailable"
    " -Wno-c++17-extensions"
    " -Wno-constexpr-not-const"
    " -Wno-missing-braces"
    " -Qunused-arguments"
  )
  # Coloured diagnostics on request.
  if(${COLORIZE_OUTPUT})
    string(APPEND CMAKE_CXX_FLAGS " -fcolor-diagnostics")
  endif()
endif()

# GNU counterpart of the colour switch, for compiler versions above 4.9.
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
  if(${COLORIZE_OUTPUT})
    string(APPEND CMAKE_CXX_FLAGS " -fdiagnostics-color=always")
  endif()
endif()

# -faligned-new: on Apple when CLANG_VERSION_STRING is at least 9.0, or with
# g++ newer than 7.0 on non-Apple hosts.
if((APPLE AND ("${CLANG_VERSION_STRING}" VERSION_GREATER_EQUAL "9.0"))
   OR (CMAKE_COMPILER_IS_GNUCXX
       AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
  string(APPEND CMAKE_CXX_FLAGS " -faligned-new")
endif()
# Optionally promote warnings to errors. When WERROR is requested but the
# compiler rejects -Werror, WERROR is reset to FALSE instead of adding the flag.
if(WERROR)
  # check_cxx_compiler_flag() is provided by this module; include it here so the
  # command is defined at the point of use rather than relying on a later or
  # transitive include.
  include(CheckCXXCompilerFlag)
  check_cxx_compiler_flag("-Werror" COMPILER_SUPPORT_WERROR)
  if(NOT COMPILER_SUPPORT_WERROR)
    set(WERROR FALSE)
  else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror")
  endif()
endif()
# Two more warning suppressions, skipped on Apple hosts.
if(NOT APPLE)
  string(APPEND CMAKE_CXX_FLAGS
    " -Wno-unused-but-set-variable"
    " -Wno-uninitialized"
  )
endif()

# Debug builds keep the frame pointer and disable optimisation.
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fno-omit-frame-pointer -O0")
# NOTE(review): CMAKE_LINKER_FLAGS_DEBUG is not one of CMake's documented
# <TYPE>_LINKER_FLAGS_<CONFIG> variables - confirm whether anything reads it.
set(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_STATIC_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -O0")

# Floating-point code generation options for all configurations.
string(APPEND CMAKE_CXX_FLAGS
  " -fno-math-errno"
  " -fno-trapping-math"
)

# Snapshot of the CALCULATE_CXX_COVERAGE environment variable, read once at
# configure time.
set(CMAKE_CXX_COVERAGE $ENV{CALCULATE_CXX_COVERAGE})

# Hardening flags prepended to CMAKE_C_FLAGS, CMAKE_CXX_FLAGS and CXXFLAGS.
# Every build gets stack protection, the relro/now/noexecstack linker options
# and -fPIE -pie. Non-Debug builds with CMAKE_CXX_COVERAGE equal to "1"
# additionally get the gcov instrumentation flags, placed before -fPIE -pie.
set(torchnpu_hardening_flags "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
if(NOT (CMAKE_BUILD_TYPE MATCHES Debug) AND (CMAKE_CXX_COVERAGE STREQUAL "1"))
  string(APPEND torchnpu_hardening_flags " -fprofile-arcs -ftest-coverage")
endif()
string(APPEND torchnpu_hardening_flags " -fPIE -pie")

set(CMAKE_C_FLAGS "${torchnpu_hardening_flags} ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${torchnpu_hardening_flags} ${CMAKE_CXX_FLAGS}")
set(CXXFLAGS "${torchnpu_hardening_flags} ${CXXFLAGS}")
unset(torchnpu_hardening_flags)

# Select the libstdc++ ABI. GLIBCXX_USE_CXX11_ABI defaults to 0 when the caller
# does not provide it, and is forwarded to the compiler as
# -D_GLIBCXX_USE_CXX11_ABI=<value>.
if(NOT DEFINED GLIBCXX_USE_CXX11_ABI)
  set(GLIBCXX_USE_CXX11_ABI 0)
endif()
message(STATUS "Determined _GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI}")
set(_GLIBCXX_USE_CXX11_ABI ${GLIBCXX_USE_CXX11_ABI})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI}")
# The expansion is quoted so that an empty value yields a false comparison
# instead of a malformed if() expression.
if("${GLIBCXX_USE_CXX11_ABI}" EQUAL 0)
  # The old ABI is pinned to -fabi-version=11.
  set(CMAKE_CXX_FLAGS "-fabi-version=11 ${CMAKE_CXX_FLAGS}")
else()
  # CMAKE_CXX_STANDARD_REQUIRED is the variable CMake consults when targets are
  # created; a plain CXX_STANDARD_REQUIRED variable has no effect on its own.
  # The latter is still set in case other project files read it.
  set(CMAKE_CXX_STANDARD_REQUIRED ON)
  set(CXX_STANDARD_REQUIRED ON)
  # With the new ABI, -fabi-version is only added on explicit request.
  if(DEFINED ABI_VERSION)
    set(CMAKE_CXX_FLAGS "-fabi-version=${ABI_VERSION} ${CMAKE_CXX_FLAGS}")
  endif()
endif()


# Expose the BUILD_LIBTORCH macro to the sources when that build flavour is
# requested.
if(DEFINED BUILD_LIBTORCH)
  add_compile_definitions(BUILD_LIBTORCH)
endif()

# Header search path for in-tree code and vendored third-party headers.
include_directories(
  ${PROJECT_SOURCE_DIR}
  ${PROJECT_SOURCE_DIR}/torch_npu/csrc/aten
  ${PROJECT_SOURCE_DIR}/third_party/hccl/inc
  ${PROJECT_SOURCE_DIR}/third_party/acl/inc
  ${PROJECT_SOURCE_DIR}/third_party/Tensorpipe
  ${PROJECT_SOURCE_DIR}/third_party/nlohmann/include
)

# Set installed PyTorch dir
# PYTORCH_INSTALL_DIR is mandatory; its header directories are added to the
# include path.
if(NOT DEFINED PYTORCH_INSTALL_DIR)
  message(FATAL_ERROR "Cannot find installed PyTorch directory")
endif()
include_directories(
  ${PYTORCH_INSTALL_DIR}/include
  ${PYTORCH_INSTALL_DIR}/include/torch/csrc/api/include
  ${PYTORCH_INSTALL_DIR}/include/torch/csrc/distributed
)

# Set Python include dir
# PYTHON_INCLUDE_DIR is mandatory as well.
if(NOT DEFINED PYTHON_INCLUDE_DIR)
  message(FATAL_ERROR "Cannot find installed Python head file directory")
endif()
include_directories(${PYTHON_INCLUDE_DIR})

# sources
# Clear the per-component source lists before the component directories are
# processed. Presumably each add_subdirectory() fills its list in this scope -
# confirm against the sub-directory CMakeLists.
#
# Lists used by both build flavours. FLOP_SRCS and CUS_DTYPE_SRCS are consumed
# when CPP_SRCS is assembled for the libtorch build as well as for the Python
# build, so they are reset unconditionally like the others.
set(ATEN_SRCS)
set(CORE_SRCS)
set(FRAMEWORK_SRCS)
set(LOGGING_SRCS)
set(FLOP_SRCS)
set(CUS_DTYPE_SRCS)

if(DEFINED BUILD_LIBTORCH)
  # libtorch-only sources.
  set(NPU_CPP_LIBS_SRCS)
else()
  # Sources only compiled into the Python extension build.
  set(DIST_SRCS)
  set(NPU_SRCS)
  set(PROF_SRCS)
  set(IPC_SRCS)
  set(UTILS_SRCS)
  set(SAN_SRCS)
  set(AFD_SRCS)
endif()

# Components configured for every build flavour, in this order.
foreach(torchnpu_component IN ITEMS aten core framework flopcount logging custom_dtype)
  add_subdirectory(${TORCHNPU_ROOT}/${torchnpu_component})
endforeach()

if(DEFINED BUILD_LIBTORCH)
  # The libtorch flavour adds only the libs directory.
  add_subdirectory(${TORCHNPU_ROOT}/libs)
else()
  # Components that are configured only when BUILD_LIBTORCH is not defined.
  foreach(torchnpu_component IN ITEMS distributed npu profiler ipc utils sanitizer afd)
    add_subdirectory(${TORCHNPU_ROOT}/${torchnpu_component})
  endforeach()
endif()

# Add subdirectory of op-plugin
# OPS_PLUGIN_SRCS is cleared first; the op-plugin headers are put on the
# include path before its directory is configured.
set(OPS_PLUGIN_SRCS)
include_directories(${PROJECT_SOURCE_DIR}/third_party/op-plugin)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/op-plugin/op_plugin)

# Optional Tensorpipe build. It is configured with BUILD_SHARED_LIBS forced ON;
# afterwards BUILD_SHARED_LIBS is set to OFF for everything that follows.
if(DEFINED BUILD_TENSORPIPE)
  add_definitions(-DUSE_RPC_FRAMEWORK)
  set(BUILD_SHARED_LIBS ON)
  if(CMAKE_VERSION VERSION_LESS "4.0.0")
    add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Tensorpipe)
  else()
    # From CMake 4.0.0 on, a policy-version floor of 3.5 is set just for the
    # Tensorpipe directory and removed again right after it.
    message(WARNING "tensorpipe forces CMake compatibility")
    set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
    add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Tensorpipe)
    unset(CMAKE_POLICY_VERSION_MINIMUM)
  endif()
  set(BUILD_SHARED_LIBS OFF)
endif()

# Assemble the complete source list for the plugin. The order of the component
# lists is kept exactly as it was for each build flavour.
if(NOT DEFINED BUILD_LIBTORCH)
  # Compile code with pybind11
  set(CPP_SRCS
    ${ATEN_SRCS}
    ${CORE_SRCS}
    ${OPS_PLUGIN_SRCS}
    ${DIST_SRCS}
    ${FLOP_SRCS}
    ${CUS_DTYPE_SRCS}
    ${LOGGING_SRCS}
    ${FRAMEWORK_SRCS}
    ${NPU_SRCS}
    ${PROF_SRCS}
    ${IPC_SRCS}
    ${UTILS_SRCS}
    ${SAN_SRCS}
    ${AFD_SRCS}
  )
else()
  set(CPP_SRCS
    ${ATEN_SRCS}
    ${CORE_SRCS}
    ${OPS_PLUGIN_SRCS}
    ${FLOP_SRCS}
    ${CUS_DTYPE_SRCS}
    ${FRAMEWORK_SRCS}
    ${LOGGING_SRCS}
    ${NPU_CPP_LIBS_SRCS}
  )
endif()

# The plugin itself: one shared library built from all collected sources.
add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS})

# Hide symbols by default when the compiler accepts -fvisibility=hidden.
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-fvisibility=hidden" COMPILER_SUPPORTS_HIDDEN_VISIBILITY)
if(${COMPILER_SUPPORTS_HIDDEN_VISIBILITY})
  target_compile_options(${PLUGIN_NAME} PRIVATE "-fvisibility=hidden")
endif()

# Optional torchair sub-project; the plugin is ordered after its
# copy_torchair_pyfiles target.
if(DEFINED BUILD_TORCHAIR)
  add_subdirectory(${TORCHNPU_THIRD_PARTY_ROOT}/torchair)
  add_dependencies(${PLUGIN_NAME} copy_torchair_pyfiles)
endif()

# Vendored fmt; its targets are built only when something depends on them.
add_subdirectory(${TORCHNPU_THIRD_PARTY_ROOT}/fmt EXCLUDE_FROM_ALL)

# NOTE(review): link_directories() is documented to affect only targets created
# after the call, and the plugin target already exists at this point -
# presumably these entries serve targets defined later; confirm.
link_directories(
  ${PYTORCH_INSTALL_DIR}/lib
  ${TORCHNPU_THIRD_PARTY_ROOT}/acl/libs
)

# Libraries are linked by full path, in the order listed below.
set(torchnpu_acl_libdir ${TORCHNPU_THIRD_PARTY_ROOT}/acl/libs)
target_link_libraries(${PLUGIN_NAME} PUBLIC ${torchnpu_acl_libdir}/libhccl.so)

if(NOT DEFINED BUILD_LIBTORCH)
  target_link_libraries(${PLUGIN_NAME} PUBLIC ${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so)
endif()

target_link_libraries(${PLUGIN_NAME} PUBLIC
  ${torchnpu_acl_libdir}/libascendcl.so
  ${torchnpu_acl_libdir}/libacl_op_compiler.so
  ${torchnpu_acl_libdir}/libge_runner.so
  ${torchnpu_acl_libdir}/libgraph.so
)

# NOTE(review): libtensorpipe.so is referenced by a fixed path under
# build/packages rather than through a CMake target - verify that this matches
# the library output directory used for the build.
if(DEFINED BUILD_TENSORPIPE)
  target_link_libraries(${PLUGIN_NAME} PUBLIC ${PROJECT_SOURCE_DIR}/build/packages/torch_npu/lib/libtensorpipe.so)
endif()

target_link_libraries(${PLUGIN_NAME} PUBLIC torch torch_cpu c10 fmt::fmt-header-only)

# Everything below applies only when BUILD_LIBTORCH is not defined.
if(NOT DEFINED BUILD_LIBTORCH)
  # ATEN_THREADING is a cache entry (default "OMP") selecting which
  # AT_PARALLEL_* macro is published on the plugin target; any value other than
  # OMP, NATIVE or TBB aborts the configure step.
  set(ATEN_THREADING "OMP" CACHE STRING "ATen parallel backend")
  message(STATUS "Using ATen parallel backend: ${ATEN_THREADING}")
  if("${ATEN_THREADING}" STREQUAL "OMP")
    set(torchnpu_parallel_definition "-DAT_PARALLEL_OPENMP=1")
  elseif("${ATEN_THREADING}" STREQUAL "NATIVE")
    set(torchnpu_parallel_definition "-DAT_PARALLEL_NATIVE=1")
  elseif("${ATEN_THREADING}" STREQUAL "TBB")
    set(torchnpu_parallel_definition "-DAT_PARALLEL_NATIVE_TBB=1")
  else()
    message(FATAL_ERROR "Unknown ATen parallel backend: ${ATEN_THREADING}")
  endif()
  target_compile_definitions(${PLUGIN_NAME} PUBLIC "${torchnpu_parallel_definition}")

  # Install rule for the plugin library, using the GNUInstallDirs library dir.
  include(GNUInstallDirs)
  target_compile_options(${PLUGIN_NAME} PRIVATE "-DC10_BUILD_MAIN_LIB")
  install(TARGETS ${PLUGIN_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

  # The toolkit directory is configured here and the plugin links npu_profiler.
  add_subdirectory(${TORCHNPU_ROOT}/toolkit)
  target_link_libraries(${PLUGIN_NAME} PUBLIC npu_profiler)
endif()

# Optional C++ API test executable, enabled by defining BUILD_GTEST.
if(DEFINED BUILD_GTEST)
  enable_testing()
  # Executables are written to build/gtest inside the source tree.
  set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build/gtest)

  # Vendored googletest and its public headers.
  add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/googletest)
  include_directories(${PROJECT_SOURCE_DIR}/third_party/googletest/googletest/include)

  # TORCH_API_TEST_SOURCES is cleared before test/cpp/api is configured and is
  # then used as the source list of test_api.
  set(TORCH_API_TEST_SOURCES)
  add_subdirectory(${PROJECT_SOURCE_DIR}/test/cpp/api)
  add_executable(test_api ${TORCH_API_TEST_SOURCES})

  target_link_libraries(test_api PUBLIC torch_npu gtest_main gtest)
endif()

# For the libtorch flavour, generate Torch_npuConfig.cmake from its template.
# Only @VAR@ references in the template are substituted (@ONLY).
if(DEFINED BUILD_LIBTORCH)
  set(torchnpu_config_template ${PROJECT_SOURCE_DIR}/cmake/Torch_npuConfig.cmake.in)
  set(torchnpu_config_output ${PROJECT_SOURCE_DIR}/build/Torch_npuConfig.cmake)
  configure_file(${torchnpu_config_template} ${torchnpu_config_output} @ONLY)
endif()

